# coding=utf-8
# Copyright 2018 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Script to average values of variables in a list of checkpoint files."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import numpy as np
import six
from six.moves import zip  # pylint: disable=redefined-builtin
import tensorflow as tf

flags = tf.flags
FLAGS = flags.FLAGS

# Explicit list of checkpoints to average. When non-empty it takes precedence
# over --num_last_checkpoints.
flags.DEFINE_string("checkpoints", "",
                    "Comma-separated list of checkpoints to average.")
# Only consulted when --checkpoints is empty; in that mode the checkpoint
# directory is taken to be os.path.dirname(--prefix).
flags.DEFINE_integer("num_last_checkpoints", 0,
                     "Averages the last N saved checkpoints."
                     " If the checkpoints flag is set, this is ignored.")
# The prefix is concatenated in front of every entry of --checkpoints
# (plain string concatenation, no path separator is inserted).
flags.DEFINE_string("prefix", "",
                    "Prefix (e.g., directory) to prepend to each checkpoint.")
flags.DEFINE_string("output_path", "/tmp/averaged.ckpt",
                    "Path to output the averaged checkpoint to.")


def checkpoint_exists(path):
    """Returns True if something that looks like a checkpoint is at `path`.

    A checkpoint is considered present when `path` itself exists, or when a
    file named `path` + ".meta" or `path` + ".index" exists.

    Args:
      path: Checkpoint path (or path prefix) to probe.

    Returns:
      True if any of the three candidate paths exists, False otherwise.
    """
    candidates = (path, path + ".meta", path + ".index")
    # any() stops at the first hit, probing in the same order as listed above.
    return any(tf.gfile.Exists(candidate) for candidate in candidates)


def _select_checkpoints():
    """Resolves the flags into the list of existing checkpoints to average.

    If --checkpoints is set, its comma-separated entries (each prefixed with
    --prefix, when given) are used. Otherwise the last --num_last_checkpoints
    entries recorded in the checkpoint state of os.path.dirname(--prefix) are
    used. Paths for which `checkpoint_exists` is False are dropped.

    Returns:
      A non-empty list of checkpoint paths.

    Raises:
      ValueError: If the flags are inconsistent or no checkpoint can be found.
    """
    if FLAGS.checkpoints:
        # Get the checkpoints list from flags and run some basic checks.
        checkpoints = [c.strip() for c in FLAGS.checkpoints.split(",")]
        checkpoints = [c for c in checkpoints if c]
        if not checkpoints:
            raise ValueError("No checkpoints provided for averaging.")
        if FLAGS.prefix:
            checkpoints = [FLAGS.prefix + c for c in checkpoints]
    else:
        # Explicit raises instead of asserts: asserts vanish under `python -O`.
        if FLAGS.num_last_checkpoints < 1:
            raise ValueError("Must average at least one model")
        if not FLAGS.prefix:
            raise ValueError("Prefix must be provided when averaging last"
                             " N checkpoints")
        checkpoint_state = tf.train.get_checkpoint_state(
            os.path.dirname(FLAGS.prefix))
        if checkpoint_state is None:
            # No checkpoint state in that directory; fall through to the
            # "Could not find checkpoints" error below instead of crashing
            # with an AttributeError on None.
            checkpoints = []
        else:
            # Checkpoints are ordered from oldest to newest.
            checkpoints = checkpoint_state.all_model_checkpoint_paths[
                -FLAGS.num_last_checkpoints:]

    checkpoints = [c for c in checkpoints if checkpoint_exists(c)]
    if not checkpoints:
        if FLAGS.checkpoints:
            raise ValueError(
                "None of the provided checkpoints exist. %s" % FLAGS.checkpoints)
        raise ValueError("Could not find checkpoints at %s" %
                         os.path.dirname(FLAGS.prefix))
    return checkpoints


def _average_variables(checkpoints):
    """Reads every variable from each checkpoint and averages the values.

    The set of variables is taken from the first checkpoint; variables whose
    name starts with "global_step" are skipped.

    Args:
      checkpoints: Non-empty list of checkpoint paths.

    Returns:
      A tuple `(var_values, var_dtypes)` of dicts keyed by variable name:
      the element-wise mean as a numpy array, and the dtype of the tensor as
      read from the checkpoints.
    """
    var_list = tf.contrib.framework.list_variables(checkpoints[0])
    # np.zeros defaults to float64, so the running sum is kept in float64
    # regardless of the dtype stored in the checkpoints.
    var_values = {
        name: np.zeros(shape)
        for (name, shape) in var_list
        if not name.startswith("global_step")
    }
    var_dtypes = {}
    for checkpoint in checkpoints:
        reader = tf.contrib.framework.load_checkpoint(checkpoint)
        for name in var_values:
            tensor = reader.get_tensor(name)
            var_dtypes[name] = tensor.dtype
            var_values[name] += tensor
        tf.logging.info("Read from checkpoint %s", checkpoint)
    for name in var_values:  # Average.
        var_values[name] /= len(checkpoints)
    return var_values, var_dtypes


def main(_):
    """Averages the selected checkpoints and saves the result.

    Selects checkpoints according to the command-line flags, averages all of
    their variables except the global step, builds a graph that contains only
    those variables plus a fresh `global_step` initialized to 0, assigns the
    averaged values and saves a checkpoint to --output_path.

    Args:
      _: Unused positional argument supplied by `tf.app.run`.

    Raises:
      ValueError: If the flags are inconsistent or no checkpoint can be found.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    checkpoints = _select_checkpoints()

    # Read variables from all checkpoints and average them.
    tf.logging.info("Reading variables and averaging checkpoints:")
    for c in checkpoints:
        tf.logging.info("%s ", c)
    var_values, var_dtypes = _average_variables(checkpoints)

    # A single explicit ordering shared by variables, placeholders and
    # assign ops, so the lists below are guaranteed to line up.
    names = list(var_values)
    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
        tf_vars = [
            tf.get_variable(
                name, shape=var_values[name].shape, dtype=var_dtypes[name])
            for name in names
        ]
    placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars]
    assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)]
    global_step = tf.Variable(
        0, name="global_step", trainable=False, dtype=tf.int64)
    # tf.global_variables / tf.global_variables_initializer replace the
    # deprecated tf.all_variables / tf.initialize_all_variables.
    saver = tf.train.Saver(tf.global_variables())

    # Build a model consisting only of variables, set them to the average values.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for name, placeholder, assign_op in zip(names, placeholders,
                                                assign_ops):
            sess.run(assign_op, {placeholder: var_values[name]})
        # Use the built saver to save the averaged checkpoint.
        saver.save(sess, FLAGS.output_path, global_step=global_step)

    tf.logging.info("Averaged checkpoints saved in %s", FLAGS.output_path)


if __name__ == "__main__":
    # Parse the command-line flags, then invoke main() with the remaining argv.
    tf.app.run(main=main)
